// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "dataflow_api.h"
#include "../scatter_common.hpp"
#include "dprint.h"

#include <algorithm>
#include <array>
#include <cstring>

namespace {

// Build row-major strides for `dims`: the last axis has stride 1 and each
// earlier axis' stride is the product of all dimension sizes to its right.
template <int32_t N>
FORCE_INLINE std::array<uint32_t, N> make_strides(const std::array<uint32_t, N>& dims) {
    std::array<uint32_t, N> strides{};
    uint32_t running_product = 1;
    int32_t axis = N;
    while (axis-- > 0) {
        strides[axis] = running_product;
        running_product *= dims[axis];
    }
    return strides;
}

// True iff every coordinate of `idx` lies strictly below the matching
// dimension size in `dims`.
template <int32_t N>
FORCE_INLINE bool in_bounds(const std::array<uint32_t, N>& idx, const std::array<uint32_t, N>& dims) {
    bool inside = true;
    for (int32_t axis = 0; axis < N && inside; ++axis) {
        inside = idx[axis] < dims[axis];
    }
    return inside;
}

// Advance `idx` to the next coordinate in row-major order (last axis varies
// fastest), wrapping each axis at its dimension size. Returns false once the
// whole coordinate space is exhausted (idx wraps back to all zeros).
template <int32_t N>
FORCE_INLINE bool next_inplace(std::array<uint32_t, N>& idx, const std::array<uint32_t, N>& dims) {
    int32_t axis = N;
    while (axis-- > 0) {
        idx[axis] += 1;
        if (idx[axis] < dims[axis]) {
            return true;  // incremented without carry
        }
        idx[axis] = 0;  // wrap this digit and carry into the next axis
    }
    return false;  // carried past the most significant axis
}

// Flatten a multi-dimensional coordinate into a linear id using precomputed
// row-major strides (inverse of from_id).
template <int32_t N>
FORCE_INLINE uint32_t to_id(const std::array<uint32_t, N>& idx, const std::array<uint32_t, N>& strides) {
    uint32_t linear = 0;
    for (int32_t axis = 0; axis < N; ++axis) {
        linear += idx[axis] * strides[axis];
    }
    return linear;
}

// Convert linear id -> coordinates (row-major, last axis fastest).
//
// The id parameter is unsigned to match to_id() and the uint32_t stick ids
// passed by the caller; the previous int32_t parameter silently narrowed and
// would have misbehaved for ids above INT32_MAX.
template <int32_t N>
std::array<uint32_t, N> from_id(uint32_t id, const std::array<uint32_t, N>& dims) {
    std::array<uint32_t, N> coord{};
    // Peel off the fastest-varying (last) axis first: each division by
    // dims[i] shifts the remaining id down by one axis.
    for (int32_t i = N - 1; i >= 0; --i) {
        coord[i] = id % dims[i];
        id /= dims[i];
    }
    return coord;
}

// this function is supposed to load either a whole stick or part of it (76800 elements)
//
// Reads `chunk_size_bytes` starting `offset_bytes` into stick `stick_id` from
// DRAM (via the tensor accessor) into one reserved page of circular buffer
// `cb`, then publishes that page. Blocks until the NOC read completes, so the
// page contents are valid as soon as cb_push_back returns.
template <typename AddrGen>
FORCE_INLINE void load_to_cb(
    const uint32_t& cb,
    const AddrGen& addr_gtor,
    const uint32_t& offset_bytes,
    const uint32_t& chunk_size_bytes,
    const uint32_t& stick_id) {
    cb_reserve_back(cb, ONE_PAGE);
    // Resolve the stick's base NOC address; the byte offset selects the chunk
    // within the stick.
    const uint64_t source_noc_address = get_noc_addr(stick_id, addr_gtor);
    const uint32_t l1_write_address = get_write_ptr(cb);

    noc_async_read(source_noc_address + offset_bytes, l1_write_address, chunk_size_bytes);
    // Barrier before pushing: consumers must never observe a page whose DMA is
    // still in flight.
    noc_async_read_barrier();

    cb_push_back(cb, ONE_PAGE);
}

// copies source stick to destination stick (first phase of scatter)
//
// Plain element-wise copy of `input_chunk_size` values from the front page of
// `input_cb` into the reserved page of `output_cb` through L1 pointers.
template <typename number_type>
FORCE_INLINE void copy_input_to_output(
    const uint32_t& input_cb, const uint32_t& output_cb, const uint32_t& input_chunk_size) {
    volatile tt_l1_ptr number_type* src =
        reinterpret_cast<volatile tt_l1_ptr number_type*>(get_read_ptr(input_cb));
    volatile tt_l1_ptr number_type* dst =
        reinterpret_cast<volatile tt_l1_ptr number_type*>(get_write_ptr(output_cb));
    for (uint32_t element = 0; element < input_chunk_size; ++element) {
        dst[element] = src[element];
    }
}

// Widen a bfloat16 bit pattern to float: bfloat16 is the upper 16 bits of an
// IEEE-754 binary32, so shifting left by 16 reconstructs the full value.
FORCE_INLINE static float bfloat16_to_float(uint16_t bfloat_val) {
    const uint32_t bits = static_cast<uint32_t>(bfloat_val) << 16;
    float result;
    std::memcpy(&result, &bits, sizeof(result));
    return result;
}

// Truncate a float to bfloat16 by keeping the top 16 bits of its binary32
// representation (round-toward-zero, same result as the previous version).
//
// Uses std::memcpy instead of a union: reading a union member other than the
// one last written is undefined behavior in C++. memcpy is the well-defined
// type-punning idiom (and matches bfloat16_to_float above); compilers emit
// the same code for it.
FORCE_INLINE static uint16_t float_to_bfloat16(float val) {
    uint32_t bits;
    std::memcpy(&bits, &val, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}

// Combine the existing output element (`input`) with the incoming source
// element according to `scatter_reduction_type`. INVALID (and any unknown
// value) means plain overwrite with the source element. For Float16_b data
// the bfloat16 bit patterns are widened to float, reduced, and truncated
// back; all other formats reduce directly in number_type.
template <typename number_type>
FORCE_INLINE number_type perform_reduction(
    number_type input, number_type source_value, ScatterReductionType scatter_reduction_type, DataFormat data_format) {
    if (data_format == DataFormat::Float16_b) {
        const float lhs = bfloat16_to_float(input);
        const float rhs = bfloat16_to_float(source_value);
        float reduced;
        switch (scatter_reduction_type) {
            case ScatterReductionType::ADD: reduced = lhs + rhs; break;
            case ScatterReductionType::MULTIPLY: reduced = lhs * rhs; break;
            case ScatterReductionType::AMAX: reduced = std::max(lhs, rhs); break;
            case ScatterReductionType::AMIN: reduced = std::min(lhs, rhs); break;
            case ScatterReductionType::INVALID:
            default: reduced = rhs; break;  // no reduction: overwrite
        }
        return float_to_bfloat16(reduced);
    }
    switch (scatter_reduction_type) {
        case ScatterReductionType::ADD: return input + source_value;
        case ScatterReductionType::MULTIPLY: return input * source_value;
        case ScatterReductionType::AMAX: return std::max(input, source_value);
        case ScatterReductionType::AMIN: return std::min(input, source_value);
        case ScatterReductionType::INVALID:
        default: return source_value;  // no reduction: overwrite
    }
}

// performs scatter on data loaded to cb with load_to_cb
//
// Walks the index chunk; every index that falls inside the currently loaded
// output window [input_offset, input_offset + input_chunk_size) and within
// the stick pulls the matching source element into the output page, applying
// the requested reduction (INVALID means plain overwrite).
//
// Changes vs. the previous version (behavioral no-ops, perf/cleanliness):
//  - get_dataformat(input_cb) is hoisted out of the per-element loop — a CB's
//    data format is fixed for the lifetime of the kernel;
//  - the unused input-CB L1 pointer is no longer computed;
//  - volatile index/source slots are read once per iteration instead of being
//    re-read through references;
//  - output_index is a plain value rather than a const ref to a temporary.
template <typename number_type, typename index_type>
FORCE_INLINE void scatter_along_chunk(
    const uint32_t& input_cb,
    const uint32_t& index_cb,
    const uint32_t& source_cb,
    const uint32_t& output_cb,
    const uint32_t& input_stick_size,
    const uint32_t& input_offset,
    const uint32_t& input_chunk_size,
    const uint32_t& index_chunk_size,
    const ScatterReductionType& scatter_reduction_type = ScatterReductionType::INVALID) {
    const uint32_t index_l1_read_addr = get_read_ptr(index_cb);
    const uint32_t source_l1_read_addr = get_read_ptr(source_cb);
    const uint32_t output_l1_write_addr = get_write_ptr(output_cb);
    volatile tt_l1_ptr index_type* index_l1_read_ptr =
        reinterpret_cast<volatile tt_l1_ptr index_type*>(index_l1_read_addr);
    volatile tt_l1_ptr number_type* source_l1_read_ptr =
        reinterpret_cast<volatile tt_l1_ptr number_type*>(source_l1_read_addr);
    volatile tt_l1_ptr number_type* output_l1_write_ptr =
        reinterpret_cast<volatile tt_l1_ptr number_type*>(output_l1_write_addr);

    // Loop-invariant: query the CB's data format once, not per element.
    const DataFormat data_format = get_dataformat(input_cb);

    // each index from the index chunk is checked whether it points
    // to any of the elements in the current output range (defined by
    // partial stick length and offset)
    for (uint32_t index_in_index_chunk = 0; index_in_index_chunk < index_chunk_size; ++index_in_index_chunk) {
        const index_type index_value = index_l1_read_ptr[index_in_index_chunk];
        if (index_value < input_offset || index_value >= input_offset + input_chunk_size) {
            continue;  // outside the currently loaded output window
        }
        if (index_value >= input_stick_size) {
            continue;  // beyond the stick itself
        }
        const number_type source_value = source_l1_read_ptr[index_in_index_chunk];
        const uint32_t output_index = index_value - input_offset;
        output_l1_write_ptr[output_index] = perform_reduction<number_type>(
            output_l1_write_ptr[output_index], source_value, scatter_reduction_type, data_format);
    }
}

}  // namespace

// Writer-side scatter kernel: for each input stick assigned to this core,
// (1) copies the input stick chunk-by-chunk into the output CB, (2) for
// sticks that have a corresponding index/source stick, streams the index and
// source sticks chunk-by-chunk and scatters source values into the output
// chunk, then (3) publishes the output chunk.
//
// Fix vs. the previous version: the phase-3 cb_push_back/cb_pop_front pair
// was indented as if it belonged to the `if (in_bounds...)` block; it is
// (and must be) executed for every input chunk regardless of that branch —
// the indentation now reflects the actual control flow.
void kernel_main() {
    constexpr auto ctas{get_ctas()};

    const uint32_t input_buffer_address = get_arg_val<uint32_t>(0);
    const uint32_t index_buffer_address = get_arg_val<uint32_t>(1);
    const uint32_t source_buffer_address = get_arg_val<uint32_t>(2);
    const uint32_t start_stick_id = get_arg_val<uint32_t>(3);
    const uint32_t sticks_for_core = get_arg_val<uint32_t>(4);
    // for the outer input/output loop (DRAM accesses per stick: input_row_elem_num / 76800)
    const uint32_t input_and_output_chunk_size = get_arg_val<uint32_t>(5);
    // for the inner index/source loop (DRAM accesses per stick per single input/output loop: index_row_elem_num /
    // 76800)
    const uint32_t index_chunk_size = get_arg_val<uint32_t>(6);
    const uint32_t source_chunk_size = get_arg_val<uint32_t>(7);
    const auto scatter_reduction_type = static_cast<ScatterReductionType>(get_arg_val<uint32_t>(8));

    const auto input_addr_gtor = TensorAccessor(ctas.input_args, input_buffer_address, ctas.input_stick_size_bytes);
    const auto index_addr_gtor = TensorAccessor(ctas.index_args, index_buffer_address, ctas.index_stick_size_bytes);
    const auto source_addr_gtor = TensorAccessor(ctas.source_args, source_buffer_address, ctas.source_stick_size_bytes);

    using input_std_type = std_type_t<get_dataformat(ctas.input_cb)>;
    using index_std_type = std_type_t<get_dataformat(ctas.index_cb)>;

    // Stick coordinates cover all axes except the last (the stick itself).
    constexpr uint32_t N = ctas.input_rank - 1;
    // generate 2 stick shape counters
    const auto input_dims{make_shape_array_from_runtime_args<N>(9)};
    const auto index_dims{make_shape_array_from_runtime_args<N>(9 + N)};

    const auto index_strides = make_strides<N>(index_dims);

    // Coordinate of the first stick this core owns, advanced in lock-step
    // with input_stick_id below.
    std::array<uint32_t, N> coord{from_id<N>(start_stick_id, input_dims)};

    for (uint32_t input_stick_id = start_stick_id; input_stick_id < start_stick_id + sticks_for_core;
         ++input_stick_id) {
        // process input/output chunks sequentially
        for (uint32_t input_offset = 0; input_offset < ctas.input_stick_size;
             input_offset += input_and_output_chunk_size) {
            // if stick is chunked, the last chunk is usually smaller
            const uint32_t input_chunk_length =
                std::min(ctas.input_stick_size - input_offset, input_and_output_chunk_size);

            // first phase: copy input data to output
            load_to_cb(
                ctas.input_cb,
                input_addr_gtor,
                input_offset * sizeof(input_std_type),
                input_chunk_length * sizeof(input_std_type),
                input_stick_id);
            cb_wait_front(ctas.input_cb, ONE_PAGE);
            cb_reserve_back(ctas.output_cb, ONE_PAGE);

            copy_input_to_output<input_std_type>(ctas.input_cb, ctas.output_cb, input_chunk_length);

            // Only input sticks whose coordinate also exists in the (possibly
            // smaller) index tensor have anything to scatter.
            if (in_bounds<N>(coord, index_dims)) {
                const uint32_t index_stick_id = to_id<N>(coord, index_strides);
                // second phase: load index and source data chunk-by-chunk and scatter
                for (uint32_t index_offset = 0, source_offset = 0; index_offset < ctas.index_stick_size;
                     index_offset += index_chunk_size, source_offset += source_chunk_size) {
                    // if stick is chunked, the last chunk is usually smaller
                    const uint32_t index_chunk_length =
                        std::min(ctas.index_stick_size - index_offset, index_chunk_size);
                    const uint32_t source_chunk_length =
                        std::min(ctas.source_stick_size - source_offset, source_chunk_size);

                    load_to_cb(
                        ctas.index_cb,
                        index_addr_gtor,
                        index_offset * sizeof(index_std_type),
                        index_chunk_length * sizeof(index_std_type),
                        index_stick_id);
                    // source tensor is sliced beforehand to match index tensor's dimensions, therefore their stick ids
                    // map 1:1
                    load_to_cb(
                        ctas.source_cb,
                        source_addr_gtor,
                        source_offset * sizeof(input_std_type),
                        source_chunk_length * sizeof(input_std_type),
                        index_stick_id);
                    cb_wait_front(ctas.index_cb, ONE_PAGE);
                    cb_wait_front(ctas.source_cb, ONE_PAGE);
                    scatter_along_chunk<input_std_type, index_std_type>(
                        ctas.input_cb,
                        ctas.index_cb,
                        ctas.source_cb,
                        ctas.output_cb,
                        ctas.input_stick_size,
                        input_offset,
                        input_chunk_length,
                        index_chunk_length,
                        scatter_reduction_type);
                    cb_pop_front(ctas.source_cb, ONE_PAGE);
                    cb_pop_front(ctas.index_cb, ONE_PAGE);
                }
            }

            // third phase: push to the output cb (runs for every chunk, even
            // when there was nothing to scatter)
            cb_push_back(ctas.output_cb, ONE_PAGE);
            cb_pop_front(ctas.input_cb, ONE_PAGE);
        }
        next_inplace<N>(coord, input_dims);
    }
}
